# -*-coding:utf-8 -*-
import csv
import hashlib
import json
import os
import random
import shutil
import string
import tempfile
import traceback
import urllib.parse

import pymysql
from kafka import KafkaProducer
from obs import ObsClient


# Base directory for temporary (downloaded) files.
LOCAL_MOUNT_PATH: str = "/tmp/"
# ASCII letters a-z and A-Z.
LETTERS: str = string.ascii_letters


# 验证环境变量是否配置
def check_configuration(context):
    region = context.getUserData('region')
    if not region:
        return 'region is not configured'

    obs_server = context.getUserData('obs_endpoint')
    if not obs_server:
        return 'obs_server is not configured'

    obs_bucket = context.getUserData('obs_bucket')
    if not obs_bucket:
        return 'obs_bucket is not configured'

    ak = context.getAccessKey().strip()
    sk = context.getSecretKey().strip()
    if not ak or not sk:
        ak = context.getUserData('ak', '').strip()
        sk = context.getUserData('sk', '').strip()
        if not ak or not sk:
            return 'AK or SK is empty'

    kafka_sever = context.getUserData("kafka_sever")
    if not kafka_sever:
        return 'kafka_sever is not configured'

    kafka_topic = context.getUserData("kafka_topic")
    if not kafka_topic:
        return 'kafka_topic is not configured'

    mysql_ipaddress = context.getUserData("mysql_ipaddress")
    if not mysql_ipaddress:
        return 'mysql_ipaddress is not configured'

    password = context.getUserData("password")
    if not password:
        return 'password is not configured'

    database = context.getUserData("database")
    if not database:
        return 'database is not configured'
    return ''


# 函数工作流调用入口
def handler(event, context):
    log = context.getLogger()
    result = check_configuration(context)
    if result:
        return result

    records = event.get("Records", [])
    if not records:
        return 'Records is empty'

    processing = Processing(context)
    try:
        for record in records:
            bucket_name, object_key = get_obs_obj_info(record)
            if object_key.endswith("csv"):
                processing.run(record)
            else:
                log.error("File type error")
    except Exception as e:
        exec_info = traceback.format_exc()
        log.error(f"failed to run extract handler: {exec_info}")
    finally:
        processing.obs_client.close()
    return 'Complete!'


class Processing:

    def __init__(self, context):
        self.logger = context.getLogger()
        obs_endpoint = context.getUserData("obs_endpoint")
        self.obs_client = new_obs_client(context, obs_endpoint)
        self.obs_bucket = context.getUserData("obs_bucket")
        self.download_dir = gen_local_download_path()
        self.kafka_sever = context.getUserData("kafka_sever")
        self.kafka_topic = context.getUserData("kafka_topic")
        self.mysql_ipaddress = context.getUserData("mysql_ipaddress")
        self.password = context.getUserData("password")
        self.database = context.getUserData("database")

    def run(self, record):
        (bucket, object_key) = get_obs_obj_info(record)
        object_key = urllib.parse.unquote_plus(object_key)
        self.logger.info(f'obs bucket: %s', bucket)
        self.logger.info(f'obs object_key: %s', object_key)
        (path, file) = os.path.split(object_key)
        download_path = "%s/%s" % (self.download_dir, file)
        download_result = self.download_file_from_obs(bucket, object_key, download_path)
        if download_result is False:
            return "Fail"
        self.write_data_to_kafka(self.kafka_sever, self.kafka_topic, self.download_dir)
        self.creating_rds_database_tables(self.mysql_ipaddress, self.password, self.database)
        return 'Success'

    # 下载文件
    def download_file_from_obs(self, bucket, obj_name, download_path):
        self.logger.info(f'start to download object %s from obs %s to local %s',
                         obj_name, bucket, download_path)
        try:
            resp = self.obs_client.getObject(bucket, obj_name,
                                             downloadPath=download_path)
            if resp.status < 300:
                self.logger.info(
                    f'succeeded to download object %s from obs %s to local %s',
                    obj_name, bucket, download_path)
                return True
            else:
                self.logger.error(
                    f"failed to download object {obj_name} from obs {bucket}, "
                    f"errorCode:{resp.errorCode} errorMessage:{resp.errorMessage}")
                return False
        except:
            self.logger.error(
                f"failed to download file {obj_name} from obs bucket{bucket}, "
                f"exp:{traceback.format_exc()}")
            return False

    # 数据写入Kafka
    def write_data_to_kafka(self, kafka_sever, kafka_topic, download_dir):
        if not os.listdir(download_dir):
            self.logger.error(f'No files under {download_dir}')
            return
        conf = {
            'bootstrap_servers': kafka_sever.split(","),
            'topic_name': kafka_topic
        }
        self.logger.info("Read files and write to kafka")
        producer = KafkaProducer(bootstrap_servers=conf['bootstrap_servers'])
        files = os.listdir(download_dir)
        for file in files:
            with open(os.path.join(download_dir, file), newline='') as csvfile:
                reader = csv.DictReader(csvfile)
                for row in reader:
                    data = bytes(json.dumps(row), encoding="utf-8")
                    producer.send(conf['topic_name'], data)
        producer.close()
        self.logger.info("Data is successfully written to Kafka.")

    # 数据库建表
    def creating_rds_database_tables(self, mysql_ipaddress, password, database):
        conn = pymysql.connect(
            host=mysql_ipaddress,
            port=3306,
            user='root',
            password=password,
            db=database
        )
        cursor = conn.cursor()
        cursor.execute("SHOW TABLES LIKE 'trade_channel_collect'")
        result = cursor.fetchone()
        if result:
            self.logger.info("Table is exist.")
            return
        sql = '''
            CREATE TABLE `{0}`.`trade_channel_collect` (
                `begin_time` VARCHAR(32) NOT NULL,
                `channel_code` VARCHAR(32) NOT NULL,
                `channel_name` VARCHAR(32) NULL,
                `cur_gmv` DOUBLE UNSIGNED NULL,
                `cur_order_user_count` BIGINT UNSIGNED NULL,
                `cur_order_count` BIGINT UNSIGNED NULL,
                `last_pay_time` VARCHAR(32) NULL,
                `flink_current_time` VARCHAR(32) NULL,
                PRIMARY KEY (`begin_time`, `channel_code`)
                )	ENGINE = InnoDB
                DEFAULT CHARACTER SET = utf8mb4
                COLLATE = utf8mb4_general_ci
                COMMENT = '各渠道的销售总额实时统计';
            '''.format(database)
        cursor.execute(sql)
        conn.commit()
        cursor.close()
        conn.close()
        self.logger.info("Successfully created table")


# 创建OBS连接
def new_obs_client(context, obs_endpoint):
    ak = context.getAccessKey()
    sk = context.getSecretKey()
    return ObsClient(access_key_id=ak, secret_access_key=sk, server=obs_endpoint)


# 文件下载存放路径
def gen_local_download_path():
    download_dir = "%s/%s" % (LOCAL_MOUNT_PATH, ''.join(random.choice(LETTERS) for _ in range(16)))
    if not os.path.exists(download_dir):
        os.makedirs(download_dir)
    return download_dir


def get_obs_obj_info(record):
    """Return ``(bucket_name, object_key)`` taken from one trigger record.

    The record is expected to carry ``record['s3']['bucket']['name']`` and
    ``record['s3']['object']['key']``. A missing key raises ``KeyError``.
    The object key is returned exactly as stored in the record.
    """
    obs_entity = record['s3']
    bucket_name = obs_entity['bucket']['name']
    object_key = obs_entity['object']['key']
    return bucket_name, object_key
